import os
import zipfile
import src.freda.config as config
import src.freda.feature_selections as fs
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import TimeSeriesSplit
from sklearn.ensemble import RandomForestClassifier
from sklearn.base import BaseEstimator
from sklearn.metrics import roc_auc_score

def handle_zip():
    """Extract the FREDA archive into the sensitive-data directory.

    Nothing happens when the marker path ``config.FILE_PROVES_FREDA_UNZIPPED``
    already exists; otherwise ``config.FREDA_ZIP`` is fully extracted into
    ``config.SENSITIVE_DATA_DIR``.
    """
    already_extracted = os.path.exists(config.FILE_PROVES_FREDA_UNZIPPED)
    if already_extracted:
        return
    with zipfile.ZipFile(config.FREDA_ZIP, 'r') as archive:
        archive.extractall(config.SENSITIVE_DATA_DIR)

def get_stata_data(filepath):
    """Read a Stata ``.dta`` file into a DataFrame, keeping raw value codes.

    ``convert_categoricals=False`` leaves value-labelled variables as their
    stored codes rather than mapping them to their labels.
    """
    frame = pd.read_stata(filepath, convert_categoricals=False)
    return frame

def get_anchors_and_paradata(directorypath):
    """Load and stack the anchor and paradata Stata files in a directory.

    Only files directly inside *directorypath* are considered; subdirectories
    are not searched. A file whose name contains
    ``config.PARADATA_FILENAME_KEYWORD`` is read as paradata. Otherwise, a
    file whose name contains ``config.ANCHOR_FILENAME_KEYWORD`` is read as
    anchor data. All other files are ignored. Files are processed in sorted
    name order so the stacked row order is deterministic across platforms.

    Args:
        directorypath: Directory holding the Stata files.

    Returns:
        Tuple ``(anchors, paradata)``. Each is the row-wise concatenation of
        the matching files, or an empty DataFrame if no file matched.

    Raises:
        FileNotFoundError: If *directorypath* does not exist or is not a
            directory.
    """
    # next(os.walk(...)) yields only the top level; walking the whole tree
    # (as list(os.walk(...)) would) is wasted work.
    top_level = next(os.walk(directorypath), None)
    if top_level is None:
        raise FileNotFoundError(
            f"Stata directory not found or not a directory: {directorypath!r}"
        )
    stem, _subdirs, filenames = top_level

    paradata_frames = []
    anchor_frames = []
    for filename in sorted(filenames):
        # os.path.join instead of a hard-coded "\\" so this works off Windows.
        filepath = os.path.join(stem, filename)
        if config.PARADATA_FILENAME_KEYWORD in filename:
            paradata_frames.append(get_stata_data(filepath))
        elif config.ANCHOR_FILENAME_KEYWORD in filename:
            anchor_frames.append(get_stata_data(filepath))

    # Concatenate once. This avoids quadratic re-copying and avoids
    # concatenating onto an empty frame, which can distort column dtypes.
    anchors = pd.concat(anchor_frames) if anchor_frames else pd.DataFrame()
    paradata = pd.concat(paradata_frames) if paradata_frames else pd.DataFrame()
    return anchors, paradata


def _wave_key(frame):
    """Label each row of *frame* as ``"<id>_<welle>"`` for aligning frames on join."""
    return frame['id'].astype(str) + "_" + frame['welle'].astype(str)


def get_raw_data():
    """Load the anchor and paradata files from ``config.STATA_DIR`` and join them.

    Both frames are re-indexed with the same ``"<id>_<welle>"`` key. A left
    join (anchors on the left) then attaches each anchor row's paradata for
    the same ``id`` and ``welle``. Paradata columns whose names clash with
    anchor columns get the suffix ``'paradata'``.

    Returns:
        The joined DataFrame, indexed by the ``"<id>_<welle>"`` key.
    """
    anchors, paradata = get_anchors_and_paradata(config.STATA_DIR)
    # Both sides must use the identical key construction. If they differ
    # (e.g. an arithmetic ``id + welle`` on one side), the join matches
    # nothing and every paradata column comes back NaN.
    anchors.index = _wave_key(anchors)
    paradata.index = _wave_key(paradata)
    data = anchors.join(paradata, rsuffix='paradata')
    return data


def drop_onevalue_columns(data:pd.DataFrame):
    """Return *data* without the columns that contain only one distinct value.

    NaN counts as a value here (``Series.unique`` keeps it). So an all-NaN
    column is dropped, while a column mixing one value with NaN is kept.
    The input frame is not modified; a new frame is returned.
    """
    constant_columns = [
        name for name in data.columns if len(data[name].unique()) == 1
    ]
    return data.drop(columns=constant_columns)


def remove_subzeros(data:pd.DataFrame):
    """Replace negative values in the numeric columns of *data* with NaN.

    NOTE(review): negative numbers are presumably the survey's missing-value
    codes (refused / don't know / not applicable). Confirm against the
    FREDA codebook.

    Only numeric (non-boolean) columns are inspected. Comparing string,
    object or datetime columns with ``0`` would raise ``TypeError``. Integer
    columns that receive NaN are upcast to float.

    Args:
        data: Frame to clean. It is modified in place.

    Returns:
        The same *data* object, for call chaining.
    """
    for column in data.select_dtypes(include='number').columns:
        values = data[column]
        # mask() puts NaN where the condition holds. Assigning the whole
        # column lets pandas upcast int -> float cleanly.
        data[column] = values.mask(values < 0)
    return data

def is_id_in_next_wave(row,data):
    """Return True when *row*'s respondent is ABSENT from the following wave.

    Despite the name, the result is a nonresponse flag. It is False when
    *data* has a row with the same ``id`` whose ``welle`` equals
    ``row['welle'] + 1``, and True otherwise.
    """
    waves_of_respondent = data.loc[data['id'] == row['id'], 'welle'].values
    next_wave = row['welle'] + 1
    return next_wave not in waves_of_respondent

def add_nonresponse(data:pd.DataFrame):
    """Flag rows whose respondent does not reappear in the following wave.

    Adds a boolean column ``Nonresponse_Next_Wave`` to *data*. A row is True
    when no row of *data* has the same ``id`` and a ``welle`` equal to this
    row's ``welle + 1``. The rule is evaluated with one set built up front,
    so the cost is O(n) rather than one full-frame filter per row (O(n^2)).

    NOTE(review): rows from the latest wave present in *data* are
    necessarily flagged True, because no later wave exists to match.
    Confirm whether they should be excluded downstream.

    Args:
        data: Frame with ``id`` and ``welle`` columns. It is modified in place.

    Returns:
        The same *data* object with the new column.
    """
    observed = set(zip(data['id'], data['welle']))
    flags = [
        (respondent, wave + 1) not in observed
        for respondent, wave in zip(data['id'], data['welle'])
    ]
    # Explicit bool dtype so an empty frame still gets a boolean column.
    data['Nonresponse_Next_Wave'] = np.array(flags, dtype=bool)
    return data

def _scale_continuous(data, columns):
    """Standardise *columns* of *data* with StandardScaler as a DataFrame."""
    if len(columns) == 0:
        # Nothing to scale; StandardScaler requires at least one feature.
        return pd.DataFrame(index=data.index)
    scaler = StandardScaler()
    scaled = scaler.fit_transform(data[columns])
    return pd.DataFrame(scaled, index=data.index, columns=columns)


def _one_hot_encode(data, columns):
    """One-hot encode *columns* of *data* into a dense DataFrame."""
    if len(columns) == 0:
        # Skip the encoder entirely when there is nothing to encode.
        return pd.DataFrame(index=data.index)
    ohe = OneHotEncoder()
    encoded = ohe.fit_transform(data[columns]).toarray()
    return pd.DataFrame(encoded, index=data.index, columns=ohe.get_feature_names_out())


def get_X(data):
    """Build the feature matrix from the feature lists in ``fs``.

    Columns named in ``fs.continuous`` are standardised with
    ``StandardScaler``. Columns named in ``fs.categorical`` are one-hot
    encoded with ``OneHotEncoder``. Names absent from ``data.columns`` are
    ignored, and a group with no surviving columns contributes no columns
    instead of raising.

    Both transformers are fitted on *data* itself. If this is called on the
    full dataset before a train/test split, test-set statistics leak into
    the features.

    Args:
        data: Frame containing (a superset of) the feature columns.

    Returns:
        DataFrame indexed like *data*: the scaled continuous columns followed
        by the one-hot columns.
    """
    con_cols = data.columns.intersection(fs.continuous)
    cat_cols = data.columns.intersection(fs.categorical)
    cont = _scale_continuous(data, con_cols)
    catg = _one_hot_encode(data, cat_cols)
    return pd.concat([cont, catg], axis=1)

def get_y(data):
    """Select the target named by ``fs.dependant`` from *data*.

    NOTE(review): whether ``fs.dependant`` is a single column name (giving a
    Series) or a list (giving a DataFrame) is defined elsewhere; confirm.
    """
    target = fs.dependant
    return data[target]